# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile

from absl.testing import absltest
from absl.testing import parameterized

import util
import numpy as np
import tensorflow as tf


class UtilTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the graph-building and checkpoint helpers in `util`."""

  def setUp(self):
    # Chain to the base class so its per-test setup is not skipped (in TF 1.x
    # this includes resetting the default graph between test methods).
    super(UtilTest, self).setUp()
    self.temp_dir = tempfile.mkdtemp(dir=absltest.get_default_test_tmpdir())

  def tearDown(self):
    tf.gfile.DeleteRecursively(self.temp_dir)
    super(UtilTest, self).tearDown()

  def _make_model(self, batch_size, num_batches, variable_initializer_value):
    """Builds a tiny model in the current default graph.

    The input pipeline yields float32 values 0 .. batch_size * num_batches - 1
    in batches of `batch_size` from a one-shot iterator. Each batch is
    multiplied by a trainable variable named 'scale'.

    Args:
      batch_size: Number of elements per batch.
      num_batches: Number of batches the input pipeline produces.
      variable_initializer_value: Initial value of the 'scale' variable.

    Returns:
      The tensor `next_batch * scale`.
    """
    np_inputs = np.arange(batch_size * num_batches)
    np_inputs = np.float32(np_inputs)
    inputs = tf.data.Dataset.from_tensor_slices(np_inputs)
    inputs = inputs.batch(batch_size).make_one_shot_iterator().get_next()
    scale = tf.get_variable(
        name='scale', dtype=tf.float32, initializer=variable_initializer_value,
        trainable=True)
    output = inputs * scale
    return output

  def test_run_graph_and_process_results(self):
    """Checks that util.run_graph_and_process_results restores a checkpoint."""

    batch_size = 3
    num_batches = 5

    # Make a graph that contains a Variable and save it to checkpoint.
    with tf.Graph().as_default():
      _ = self._make_model(
          batch_size=batch_size, num_batches=num_batches,
          variable_initializer_value=2.0)
      saver = tf.train.Saver(tf.trainable_variables())
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, self.temp_dir + '/model-')

    # Make another copy of the graph, and process data using this one.
    with tf.Graph().as_default():
      # We intentionally give this graph's Variable a different initial value
      # than the graph above. When we restore from checkpoint, we will get the
      # value from the first graph. This checks that the Variables are
      # properly restored from checkpoint.
      ops_to_fetch = self._make_model(
          batch_size=batch_size, num_batches=num_batches,
          variable_initializer_value=3.0
      )

      results = []
      def process_fetched_values_fn(np_array):
        results.append(np_array)

      model_checkpoint_path = self.temp_dir

      util.run_graph_and_process_results(ops_to_fetch, model_checkpoint_path,
                                         process_fetched_values_fn)

      results = np.concatenate(results, axis=0)
      # 2.0 is the checkpointed value of 'scale', not the 3.0 initializer.
      expected_results = np.arange(num_batches * batch_size) * 2.0

      self.assertAllEqual(results, expected_results)

  # Integer sub-batch sizes against a batch of 50: one that does not divide
  # 50, one that does, and one larger than the whole batch.
  @parameterized.parameters(7, 10, 65)
  def test_map_predictor(self, sub_batch_size):
    """Checks util.map_predictor against a direct call of predictor_fn."""
    # A fresh graph keeps this test's ops out of the shared default graph.
    with tf.Graph().as_default():
      input_op = {
          'a': tf.random_normal(shape=(50, 5)),
          'b': tf.random_normal(shape=(50, 5))
      }

      def predictor_fn(data):
        return data['a'] + data['b']

      mapped_prediction = util.map_predictor(
          input_op, predictor_fn, sub_batch_size=sub_batch_size)
      unmapped_prediction = predictor_fn(input_op)
      # Both predictions are computed in one sess.run below, so they see the
      # same random draw of input_op.
      difference = tf.reduce_mean(
          tf.squared_difference(mapped_prediction, unmapped_prediction))
      with tf.Session() as sess:
        self.assertLess(
            sess.run(difference), 1e-6,
            'The output of map_predictor does not match a direct '
            'application of predictor_fn.')

  def test_value_op_with_initializer(self):
    """Test correctness of util.value_op_with_initializer."""
    # A fresh graph avoids clashing with any 'value' variable that another
    # test may have left in the default graph.
    with tf.Graph().as_default():
      base_value_op = tf.get_variable('value', initializer=0.)

      def make_value_op():
        return base_value_op

      def make_init_op(value):
        # This is a simple assignment that could have been achieved by
        # changing the initializer above. However, in other use cases of
        # value_op_with_initializer, the constructed value requires
        # data-dependent computation that can't be done via an initializer.
        return value.assign(tf.ones_like(value))

      value_op = util.value_op_with_initializer(make_value_op, make_init_op)

      # Check that the value of the Variable generated by make_value_op() is
      # the value constructed by make_init_op, not the value from the
      # initializer passed to the Variable's constructor.
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        self.assertAllEqual(sess.run(base_value_op), 0.0)
        self.assertAllEqual(sess.run(value_op), 1.0)
        self.assertAllEqual(sess.run(base_value_op), 1.0)

  def test_scatter_by_anchor_indices(self):
    """Checks util.scatter_by_anchor_indices on hand-computed examples."""

    def _validate(anchor_indices, data, index_shift, expected_output):
      with tf.Graph().as_default():
        output = util.scatter_by_anchor_indices(anchor_indices, data,
                                                index_shift)
        with tf.Session() as sess:
          actual_output = sess.run(output)
      self.assertAllClose(
          np.array(expected_output, dtype=np.float32), actual_output)

    data = [[1, 2, 3], [4, 5, 6]]

    anchor_indices = [1, 1]
    index_shift = 0
    expected_output = [[2, 1, 0], [5, 4, 0]]
    _validate(anchor_indices, data, index_shift, expected_output)

    anchor_indices = [2, 2]
    index_shift = 0
    expected_output = [[3, 2, 1], [6, 5, 4]]
    _validate(anchor_indices, data, index_shift, expected_output)

    # Anchor 1 with index_shift=1 is expected to match anchor 2 with no shift.
    anchor_indices = [1, 1]
    index_shift = 1
    expected_output = [[3, 2, 1], [6, 5, 4]]
    _validate(anchor_indices, data, index_shift, expected_output)

    anchor_indices = [0, 1]
    index_shift = 1
    expected_output = [[2, 1, 0], [6, 5, 4]]
    _validate(anchor_indices, data, index_shift, expected_output)

    data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    anchor_indices = [1, 2, 3]
    index_shift = 0
    expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]]
    _validate(anchor_indices, data, index_shift, expected_output)

    data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    anchor_indices = [0, 1, 2]
    index_shift = 1
    expected_output = [[2, 1, 0, 0], [7, 6, 5, 0], [12, 11, 10, 9]]
    _validate(anchor_indices, data, index_shift, expected_output)

def _main():
  """Run every test in this module with TensorFlow's test runner."""
  tf.test.main()


if __name__ == '__main__':
  _main()
